import inspect
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import mask_loss


class CrossEntropy(nn.Module):
    """Cross-entropy loss with optional label smoothing.

    Thin ``nn.Module`` wrapper around ``F.cross_entropy``. ``forward`` is
    decorated with ``mask_loss`` (from ``.utils``); see that helper for how
    masking is applied to the inputs/result.

    Args:
        eps: Label-smoothing factor, forwarded to ``F.cross_entropy`` as
            ``label_smoothing``. ``0.`` (the default) means no smoothing.
            On torch builds whose ``F.cross_entropy`` has no
            ``label_smoothing`` argument, a nonzero ``eps`` is ignored and a
            warning is emitted when the module is constructed.
    """

    # Feature-detect the `label_smoothing` argument (introduced in torch 1.10)
    # instead of string-matching ``torch.__version__``: the previous
    # ``startswith('1.10')`` test was False on 1.11+ and 2.x, so ``eps`` was
    # silently dropped on every release newer than 1.10.
    _SUPPORTS_LABEL_SMOOTHING = (
        'label_smoothing' in inspect.signature(F.cross_entropy).parameters
    )

    def __init__(self, eps=0.):
        super().__init__()
        self.eps = eps
        if eps and not self._SUPPORTS_LABEL_SMOOTHING:
            # Keep the old best-effort behaviour (run without smoothing), but
            # make it visible instead of silently ignoring the hyperparameter.
            warnings.warn(
                f'label smoothing (eps={eps}) is not supported by torch '
                f'{torch.__version__}; it will be ignored.',
                stacklevel=2,
            )

    @mask_loss
    def forward(self, logits, targets):
        """Return ``F.cross_entropy(logits, targets)``.

        ``self.eps`` is passed as ``label_smoothing`` when the installed
        torch supports that argument; otherwise plain cross-entropy is used.
        Accepted shapes/dtypes of ``logits`` and ``targets`` are those of
        ``F.cross_entropy`` (as possibly adjusted by ``mask_loss``).
        """
        if self._SUPPORTS_LABEL_SMOOTHING:
            return F.cross_entropy(logits, targets, label_smoothing=self.eps)
        return F.cross_entropy(logits, targets)
